"""
Main file for evaluating a trained Yolo model on the Pascal VOC test set
"""
import argparse
import torch
import torchvision.transforms as transforms

from tqdm import tqdm
from torch.utils.data import DataLoader

from model.model import Yolov1
from model.loss import YoloLoss
from model.dataset import VOCDataset
from model.utils import (
    get_bboxes,
    mean_average_precision,
)

# DataLoader settings and the checkpoint file this script reads.
NUM_WORKERS = 2
PIN_MEMORY = True
LAST_MODEL_FILE = "last.pth.tar"

# Seed PyTorch's global RNG so repeated runs are reproducible.
seed = 123
torch.manual_seed(seed)


class Compose:
    """Chain image transforms while carrying bounding boxes along unchanged.

    Each transform is called with the image only; the bounding boxes are
    handed back exactly as they were passed in.
    """

    def __init__(self, transforms):
        # Ordered collection of callables, each mapping an image to an image.
        self.transforms = transforms

    def __call__(self, img, bboxes):
        """Apply every transform to ``img`` in order; return ``(img, bboxes)``."""
        result = img
        for step in self.transforms:
            result = step(result)
        return result, bboxes


# Evaluation preprocessing: resize each image to 448x448, then convert it to a
# tensor.
_resize_step = transforms.Resize((448, 448))
_to_tensor_step = transforms.ToTensor()
transform = Compose([_resize_step, _to_tensor_step])


def test_fn(test_loader, model, loss_fn, device):
    """Run one pass over ``test_loader`` and print the mean per-batch loss.

    Args:
        test_loader: Iterable yielding ``(x, y)`` batches; both items are moved
            to ``device`` before use.
        model: Callable applied to ``x`` to produce predictions.
        loss_fn: Callable taking ``(predictions, y)`` and returning a loss
            object whose ``.item()`` yields a Python number.
        device: Device the batch tensors are moved to.
    """
    loop = tqdm(test_loader, leave=True)
    batch_losses = []

    # Evaluation only: no backward pass follows, so skip autograd bookkeeping
    # and avoid holding activations in memory.
    with torch.no_grad():
        for x, y in loop:
            x, y = x.to(device), y.to(device)
            loss_value = loss_fn(model(x), y).item()
            batch_losses.append(loss_value)

            # update progress bar
            loop.set_postfix(loss=loss_value)

    if not batch_losses:
        # An empty loader would otherwise raise ZeroDivisionError below.
        print("No batches were evaluated; mean loss is undefined")
        return

    print(f"Mean loss was {sum(batch_losses)/len(batch_losses)}")


def main(args):
    """Evaluate the checkpoint stored in ``LAST_MODEL_FILE`` on ``data/test.txt``.

    Builds the model, restores its weights, prints the mean test loss and then
    the mean average precision at an IoU threshold of 0.5.

    Args:
        args: Parsed command-line arguments; ``args.device`` and
            ``args.batch_size`` are read.
    """
    model = Yolov1(split_size=7, num_boxes=2, num_classes=20).to(args.device)
    loss_fn = YoloLoss()

    # map_location remaps the stored tensors onto the requested device, so a
    # checkpoint saved on a GPU can still be loaded when running on the CPU.
    checkpoint = torch.load(LAST_MODEL_FILE, map_location=args.device)
    model.load_state_dict(checkpoint["state_dict"])
    print(f"Loaded model: {LAST_MODEL_FILE}")

    test_dataset = VOCDataset(
        "data/test.txt",
        transform=transform,
    )

    # Evaluation should see every sample exactly once: no shuffling, and the
    # final partial batch is kept rather than dropped.
    # NOTE(review): assumes get_bboxes and the loss accept a smaller last
    # batch -- confirm against model.utils / model.loss.
    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=args.batch_size,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        shuffle=False,
        drop_last=False,
    )

    model.eval()

    # Inference only, so gradient tracking is disabled for both passes.
    with torch.no_grad():
        test_fn(test_loader, model, loss_fn, args.device)

        pred_boxes, target_boxes = get_bboxes(
            test_loader, model, iou_threshold=0.5, threshold=0.4, device=args.device
        )

    mean_avg_prec = mean_average_precision(
        pred_boxes, target_boxes, iou_threshold=0.5, box_format="midpoint"
    )
    print(f"Test mAP: {mean_avg_prec}")


def parse_args():
    """Parse the command-line options of the evaluation script.

    Returns:
        argparse.Namespace with ``batch_size`` (int, default 16) and ``device``
        (str; defaults to ``cuda:0`` when CUDA is available, otherwise ``cpu``).
    """
    # This script evaluates a saved checkpoint; it does not train.
    parser = argparse.ArgumentParser(description="Yolo-v1 test")

    # Fall back to the CPU so the script still starts on hosts without CUDA.
    default_device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # fmt: off
    parser.add_argument("--batch-size", type=int, default=16, help="batch size")
    parser.add_argument("--device", default=default_device, help="Device used for inference")
    # fmt: on

    return parser.parse_args()


# Script entry point: parse the CLI options and hand them straight to main().
if __name__ == "__main__":
    main(parse_args())
